import os
import sys
import warnings
import pickle
import torch
import torch.nn as nn
import math
from data.data_loader import data_loader, get_deri_loader
from model.linear import Linear
from model.conv import CNNnet
from utils.essen_plot import plot_loss, plot_model_output, plot_eig_vs_var, plot_eig_vs_mean, plot_loss_landscape
from utils.save_path import CheckpointSaver, create_save_dir, save_code_and_config
from utils.derivatives_of_parameters import derivatives, get_hessian_eig
from utils.get_weight_matrix_and_pca import Get_weight_matrix_and_pca, get_loss_for_weight_matrix
import shutil
import platform
import model.vgg as vgg
import numpy as np
from config.config import parse_args
import copy
import random
from sklearn.decomposition import PCA
import torch.nn.functional as F
import time
from torch.autograd.variable import Variable
from matplotlib import pyplot as plt
# Silence every warning category for the whole process.
warnings.filterwarnings("ignore")


# Expose CUDA devices 0-7 to this process.
# NOTE(review): this takes effect only if it runs before CUDA is first
# initialised in the process -- confirm nothing imported above touches CUDA.
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3,4,5,6,7'


def one_hot(x, class_count):
    """Return one-hot encodings for the class indices in ``x``.

    Args:
        x: an index (or sequence/tensor of indices) in ``[0, class_count)``.
        class_count: number of classes, i.e. the width of each encoding.

    Returns:
        The rows of a ``class_count`` x ``class_count`` identity matrix
        selected by ``x``.
    """
    identity = torch.eye(class_count)
    return identity[x, :]


def seed_torch(seed=1029):
    """Seed every random number generator used by the experiments.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices),
    and switches cuDNN to its deterministic, non-benchmarking mode.

    Args:
        seed: the integer seed shared by all generators.
    """
    # Pin the hash seed; the intent is to disable hash randomisation so that
    # experiments are reproducible.
    # NOTE(review): an interpreter reads PYTHONHASHSEED at start-up, so this
    # presumably only affects processes launched afterwards -- confirm.
    os.environ['PYTHONHASHSEED'] = str(seed)

    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # needed when several GPUs are in use

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def get_model_para(net):
    """Collect the raw ``.data`` tensor of every parameter of ``net``.

    Args:
        net: a module exposing ``parameters()``.

    Returns:
        A list of tensors, in the order ``net.parameters()`` yields them.
    """
    tensors = []
    for param in net.parameters():
        tensors.append(param.data)
    return tensors


def get_diff_weights(weights, weights2):
    """Build the direction pointing from ``weights`` to ``weights2``.

    Args:
        weights: the starting list of weights.
        weights2: the target list of weights.

    Returns:
        A list holding ``target - start`` for each aligned pair; pairing
        stops at the shorter of the two inputs.
    """
    direction = []
    for start, target in zip(weights, weights2):
        direction.append(target - start)
    return direction


def nplist_to_tensor(nplist):
    """Flatten a list of arrays into a single 1-D tensor.

    Args:
        nplist: a list of numpy vectors/arrays, e.g. a direction loaded
            from an h5 file.

    Returns:
        One concatenated 1-D tensor. Zero-dimensional (scalar) entries are
        skipped.
    """
    pieces = []
    for entry in nplist:
        # Each entry is multiplied by a float64 one before conversion.
        # NOTE(review): presumably a dtype promotion -- confirm the intent.
        tensor = torch.tensor(entry*np.float64(1.0))
        if tensor.dim() == 0:
            # Scalars carry no vector component; leave them out.
            continue
        if tensor.dim() == 1:
            pieces.append(tensor)
        else:
            pieces.append(tensor.view(tensor.numel()))
    return torch.cat(pieces)


def project_1D(w, d):
    """Project vector ``w`` onto vector ``d`` and return the signed length.

    Args:
        w: vectorized weights.
        d: vectorized direction, same length as ``w``.

    Returns:
        The scalar projection ``dot(w, d) / ||d||``.

    Raises:
        AssertionError: if ``w`` and ``d`` differ in length.
    """
    assert len(w) == len(d), 'dimension does not match for w and d'
    return np.dot(w, d)/np.linalg.norm(d)


def _to_numpy(vec):
    """Return ``vec`` as a numpy array, accepting torch tensors or array-likes."""
    if hasattr(vec, 'numpy'):
        return vec.numpy()
    return np.asarray(vec)


def project_2D(d, dx, dy, proj_method):
    """Project vector ``d`` onto the plane spanned by ``dx`` and ``dy``.

    Args:
        d: vectorized weights.
        dx: vectorized direction.
        dy: vectorized direction.
        proj_method: ``'cos'`` projects onto each direction independently
            (appropriate when ``dx`` and ``dy`` are orthogonal); ``'lstsq'``
            solves the least-squares problem ``[dx dy] @ [x, y] = d``.

    Returns:
        x, y: the projection coordinates.

    Raises:
        ValueError: if ``proj_method`` is not ``'cos'`` or ``'lstsq'``.
    """
    if proj_method == 'cos':
        # Valid when dx and dy are orthogonal.
        x = project_1D(d, dx)
        y = project_1D(d, dy)
    elif proj_method == 'lstsq':
        # Solve the least-squares problem A @ [x, y] = d. Inputs may be
        # torch tensors or numpy arrays, matching what 'cos' accepts.
        A = np.vstack([_to_numpy(dx), _to_numpy(dy)]).T
        # rcond=None selects NumPy's current default cutoff explicitly and
        # avoids the FutureWarning older NumPy versions emit.
        x, y = np.linalg.lstsq(A, _to_numpy(d), rcond=None)[0]
    else:
        raise ValueError(
            "unknown proj_method %r; expected 'cos' or 'lstsq'" % (proj_method,))

    return x, y


def get_model_lst(path_lst, args, model):
    """Load every checkpoint of each run as a flat numpy parameter vector.

    For each run directory the files in its ``model/`` sub-directory are
    ordered with ``sort_model``, loaded into ``model`` (which is overwritten
    in place), flattened and converted to numpy.

    Args:
        path_lst: run directories; each is expected to end with a path
            separator because ``model/`` is appended directly.
        args: configuration object; ``args.device`` is used as the
            ``map_location`` when loading.
        model: network whose state dict is replaced by each checkpoint.

    Returns:
        A list with one entry per run, each a list of 1-D numpy arrays
        (one per checkpoint, in sorted order).
    """
    runs = []
    for run_dir in path_lst:
        ckpt_dir = '%smodel/' % (run_dir)
        ckpt_names = sorted(os.listdir(ckpt_dir), key=sort_model)
        run_vectors = []
        for ckpt_name in ckpt_names:
            # NOTE(review): torch.load unpickles the file -- only point this
            # at checkpoints from a trusted source.
            checkpoint = torch.load('%s%s' % (ckpt_dir, ckpt_name),
                                    map_location=args.device)
            model.load_state_dict(checkpoint['state_dict'][0])
            flat = nplist_to_tensor(get_model_para(model))
            print(flat)
            run_vectors.append(flat.cpu().numpy())
        runs.append(run_vectors)
    return runs


def set_weights(net, weights, directions=None, step=None):
    """Overwrite the parameters of ``net``, optionally offset along directions.

    Args:
        net: the network whose parameters are modified in place.
        weights: list of tensors, one per parameter of ``net``.
        directions: ``None`` to copy ``weights`` as-is, otherwise a list of
            one or two directions (each a list of tensors aligned with
            ``weights``).
        step: step size(s) along the direction(s): indexable with two
            entries when two directions are given, a single scalar otherwise.

    Raises:
        AssertionError: if ``directions`` is given without ``step``.
    """
    # Without a direction there is nothing to step along: plain overwrite.
    if directions is None:
        for param, value in zip(net.parameters(), weights):
            param.data.copy_(value.type(type(param.data)))
        return

    assert step is not None, 'If a direction is specified then step must be specified as well'

    if len(directions) == 2:
        x_dir, y_dir = directions[0], directions[1]
        offsets = [a*step[0] + b*step[1] for (a, b) in zip(x_dir, y_dir)]
    else:
        offsets = [a*step for a in directions[0]]

    for param, value, offset in zip(net.parameters(), weights, offsets):
        param.data = value + torch.Tensor(offset).type(type(value))


def npvec_to_tensorlist(direction, params):
    """Split a flat vector into tensors shaped like the entries of ``params``.

    Args:
        direction: a flat vector (e.g. a direction loaded from an h5 file)
            whose length equals the total element count of ``params``.
        params: either a list of parameter tensors, or a mapping (such as a
            state dict) whose values are tensors.

    Returns:
        A list of tensors with the same shapes as the entries of ``params``.
        For a list input the result is a deep copy of ``params`` filled from
        ``direction``; ``params`` itself is left untouched.

    Raises:
        AssertionError: if ``direction`` is not fully consumed.
    """
    cursor = 0
    if isinstance(params, list):
        shaped = copy.deepcopy(params)
        for target in shaped:
            count = target.numel()
            chunk = torch.tensor(direction[cursor:cursor + count])
            target.copy_(chunk.view(target.size()))
            cursor += count
    else:
        shaped = []
        for (_, reference) in params.items():
            count = reference.numel()
            chunk = torch.Tensor(direction[cursor:cursor + count])
            shaped.append(chunk.view(reference.size()))
            cursor += count

    assert(cursor == len(direction))
    return shaped


def eval_loss(net, criterion, loader, use_cuda=False):
    """Evaluate the mean loss and accuracy of ``net`` over ``loader``.

    The network is switched to eval mode (and moved to the GPU when
    ``use_cuda`` is set); gradients are disabled during evaluation.

    Args:
        net: the neural net model; its output is treated as
            ``(batch, num_classes)`` scores.
        criterion: ``nn.CrossEntropyLoss`` (applied to the raw outputs and
            integer targets) or ``nn.MSELoss`` (applied to the softmax of
            the outputs and one-hot targets).
        loader: iterable of ``(inputs, targets)`` batches with integer class
            targets.
        use_cuda: move the network and each batch to the GPU.

    Returns:
        ``(loss, accuracy)``: the sample-weighted mean loss and the accuracy
        in percent.

    Raises:
        TypeError: if ``criterion`` is neither of the supported losses.
        ValueError: if ``loader`` yields no samples.
    """
    use_cross_entropy = isinstance(criterion, nn.CrossEntropyLoss)
    if not use_cross_entropy and not isinstance(criterion, nn.MSELoss):
        raise TypeError(
            'eval_loss supports nn.CrossEntropyLoss and nn.MSELoss, got %s'
            % type(criterion).__name__)

    if use_cuda:
        net.cuda()
    net.eval()

    correct = 0
    total_loss = 0
    total = 0  # number of samples

    with torch.no_grad():
        for inputs, targets in loader:
            batch_size = inputs.size(0)
            total += batch_size
            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda()

            outputs = net(inputs)
            if use_cross_entropy:
                loss = criterion(outputs, targets)
            else:
                # MSE compares class probabilities with one-hot targets. The
                # class count comes from the network output rather than a
                # hard-coded 10, so any number of classes works.
                outputs = F.softmax(outputs, dim=1)
                one_hot_targets = torch.zeros_like(outputs).scatter_(
                    1, targets.view(batch_size, 1), 1.0)
                loss = criterion(outputs, one_hot_targets)

            # The criterion returns a per-batch mean; weight it by the batch
            # size so a smaller final batch does not skew the overall mean.
            total_loss += loss.item()*batch_size
            _, predicted = torch.max(outputs, 1)
            correct += predicted.eq(targets).sum().item()

    if total == 0:
        raise ValueError('eval_loss received an empty loader')

    return total_loss/total, 100.*correct/total


def plot_contour_trajectory(X, Y, loss_all, x_lst, y_lst, path, levels=None):
    """Draw a 2D loss contour with trajectories on top and save it as a PNG.

    Args:
        X: grid of x coordinates, as accepted by ``plt.contour``.
        Y: grid of y coordinates, as accepted by ``plt.contour``.
        loss_all: loss value at each grid point.
        x_lst: one sequence of x coordinates per trajectory.
        y_lst: one sequence of y coordinates per trajectory, aligned with
            ``x_lst``.
        path: output prefix; the figure is written to
            ``path + '_2dcontour_proj_all_2.png'``.
        levels: contour levels; defaults to ``np.arange(0.1, 10, 0.5)``.
            Pass e.g. a ``np.logspace`` range when the loss spans several
            orders of magnitude.
    """
    if levels is None:
        levels = np.arange(0.1, 10, 0.5)

    fig = plt.figure()
    try:
        contours = plt.contour(X, Y, loss_all, levels=levels)
        plt.clabel(contours, inline=1, fontsize=8)
        for xs, ys in zip(x_lst, y_lst):
            plt.plot(xs, ys, marker='.')

        fig.savefig(path + '_2dcontour_proj_all_2.png', dpi=300,
                    bbox_inches='tight', format='png')
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # release it so repeated calls do not accumulate figures.
        plt.close(fig)


def sort_model(name):
    """Sort key: the integer embedded in a checkpoint file name.

    The number is the text between the first ``'p'`` in ``name`` and the
    next ``'.'``, e.g. ``'tmp299.pth.tar'`` gives ``299``.

    Args:
        name: checkpoint file name.

    Returns:
        The parsed integer.
    """
    after_first_p = name.split('p')[1]
    digits, _, _ = after_first_p.partition('.')
    return int(digits)


def main():
    """Plot the training-loss landscape around a checkpoint with its trajectory.

    Steps:
      1. Parse the configuration, seed the RNGs and build the network.
      2. Load every checkpoint of two training runs as flat parameter
         vectors and fit a 2-component PCA on all of them.
      3. Project the checkpoints of the second run, taken relative to that
         run's last checkpoint, onto the two principal directions.
      4. Evaluate the training loss on an ``n_space`` x ``n_space`` grid that
         covers the projected trajectory plus a 1/8 margin on each side,
         perturbing the weights of ``tmp299.pth.tar`` from the second run
         along the two directions.
      5. Save the loss grid to ``loss.txt`` and the contour plot next to it.

    Raises:
        ValueError: if ``args.network_type`` is not ``'linear'`` or ``'vgg'``.
    """
    args, _ = parse_args()
    print(args.hidden_layers_width)

    # A fixed seed keeps the run reproducible.
    seed = 0
    seed_torch(seed)
    args.seed = seed

    n_space = 20  # grid resolution along each principal direction
    args.device = torch.device("cuda:%s" % (
        args.device_rank) if torch.cuda.is_available() else "cpu")
    args.model_name = '_'.join(map(str, args.hidden_layers_width))

    if args.network_type == 'linear':
        model = Linear(args.t, args.hidden_layers_width, args.input_dim,
                       args.output_dim, nn.ReLU(), args.initialization,
                       args.dropout, args.dropout_pro, args.bias).to(args.device)
    elif args.network_type == 'vgg':
        model = vgg.VGG9(dropout=False).to(args.device)
    else:
        raise ValueError('unsupported network_type: %r' % (args.network_type,))
    print(model)

    path = '/home/xxx/data/sgd/test36/50_50/0.5/'
    path1 = '/home/xxx/data/sgd/test36/50_50/0.5/20210922053038757967/'
    path2 = '/home/xxx/data/sgd/test36/50_50/0.5/20210922061028446525/'
    para_lst = get_model_lst([path1, path2], args, model)

    # The landscape is centred on this checkpoint of the second run.
    checkpoint = torch.load('%smodel/tmp299.pth.tar' % (path2),
                            map_location=args.device)
    model.load_state_dict(checkpoint['state_dict'][0])
    model = model.cpu()
    weight = get_model_para(model)

    print("Perform PCA on the models")
    pca = PCA(n_components=2)
    pca.fit(np.array(para_lst[0]+para_lst[1]))
    pc1 = np.array(pca.components_[0])
    pc2 = np.array(pca.components_[1])

    # Project the second run's checkpoints, relative to its final one.
    # Iterate over that run itself so runs with different checkpoint counts
    # neither truncate the trajectory nor index out of range.
    trajectory = para_lst[1]
    x_lst = []
    y_lst = []
    for params in trajectory:
        x, y = project_2D(params-trajectory[-1], pc1, pc2, 'cos')
        x_lst.append(x)
        y_lst.append(y)
        print(" (%.4f, %.4f)" % (x, y))
    x_lst_all = [x_lst]
    y_lst_all = [y_lst]
    print(x_lst_all)

    # Grid bounds: the trajectory's extent padded by 1/8 on each side.
    x_min = min(min(row) for row in x_lst_all)
    x_max = max(max(row) for row in x_lst_all)
    y_min = min(min(row) for row in y_lst_all)
    y_max = max(max(row) for row in y_lst_all)
    x_bound = [x_min-1/8*(x_max-x_min), x_max+1/8*(x_max-x_min)]
    y_bound = [y_min-1/8*(y_max-y_min), y_max+1/8*(y_max-y_min)]
    print(x_bound)

    x_direction = npvec_to_tensorlist(pc1, weight)
    y_direction = npvec_to_tensorlist(pc2, weight)
    X = np.linspace(x_bound[0], x_bound[1], n_space)
    Y = np.linspace(y_bound[0], y_bound[1], n_space)
    xcoord_mesh, ycoord_mesh = np.meshgrid(X, Y)
    train_loader, _ = data_loader(
        training_batch_size=args.training_batch_size,
        test_batch_size=args.test_batch_size,
        training_size=args.training_size, data=args.data, args=args)

    loss_all = np.zeros_like(xcoord_mesh)
    criterion = nn.CrossEntropyLoss()
    # Follow the device chosen above instead of unconditionally requesting
    # CUDA, so the script also runs on CPU-only machines.
    use_cuda = args.device.type == 'cuda'

    for row in range(n_space):
        for col in range(n_space):
            print(row, col)
            set_weights(model, weight,
                        directions=[x_direction, y_direction],
                        step=[xcoord_mesh[row, col], ycoord_mesh[row, col]])
            loss, _ = eval_loss(model, criterion, train_loader, use_cuda)
            loss_all[row, col] = loss

    np.savetxt('%sloss.txt' % (path), loss_all)

    print(X, Y)

    plot_contour_trajectory(xcoord_mesh, ycoord_mesh,
                            loss_all, x_lst_all, y_lst_all, path)


# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
